import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl

class PhysicsAwareGATv2Conv(nn.Module):
    """
    A physics-aware variant of the GATv2 convolutional layer.

    This layer extends the standard Graph Attention (GATv2) mechanism by optionally
    incorporating edge features into the attention computation. It computes per-head
    attention scores and aggregates neighbor features accordingly, with an optional
    residual connection. The output dimension per head is `out_feats`, and there are
    `num_heads` parallel attention heads.

    Arguments:
        in_feats (int): Dimension of input node features.
        out_feats (int): Dimension of output features per head.
        num_heads (int): Number of attention heads.
        feat_drop (float): Dropout rate on input features before projection.
        attn_drop (float): Dropout rate on attention scores.
        negative_slope (float): Slope for LeakyReLU in attention computation.
        residual (bool): Whether to add a residual connection from input to output.
        activation (callable or None): Activation function applied after aggregation.
        allow_zero_in_degree (bool): If False, raises an error when nodes have zero in-degree.
        use_edge_features (bool): If True, project edge features and include them in attention.
    """

    def __init__(self, in_feats, out_feats, num_heads=1, feat_drop=0.0,
                 attn_drop=0.0, negative_slope=0.2, residual=False, activation=None,
                 allow_zero_in_degree=True, use_edge_features=True):
        super(PhysicsAwareGATv2Conv, self).__init__()

        # Store hyperparameters
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.num_heads = num_heads
        self.use_edge_features = use_edge_features

        # Linear projections for source and target node features.
        # We project input features to (out_feats * num_heads), then reshape.
        self.fc_src = nn.Linear(in_feats, out_feats * num_heads, bias=False)
        self.fc_dst = nn.Linear(in_feats, out_feats * num_heads, bias=False)

        # If using edge features, also project them to (out_feats * num_heads)
        if use_edge_features:
            self.fc_edge = nn.Linear(in_feats, out_feats * num_heads, bias=False)

        # Attention weight vectors for left (source) and right (target) projections.
        # Shape: (1, num_heads, out_feats). We will broadcast these over the batch of edges.
        self.attn_l = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats)))
        self.attn_r = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats)))

        # Dropout modules for features and attention scores
        self.feat_drop = nn.Dropout(feat_drop)
        self.attn_drop = nn.Dropout(attn_drop)

        # Optional activation (e.g., F.elu) applied after aggregation
        self.activation = activation
        # LeakyReLU negative slope for attention coefficient computation
        self.negative_slope = negative_slope

        # If residual=True, add a skip connection from input to output
        self.residual = residual
        if residual:
            # If input and output dimensions mismatch, project input to match
            if in_feats != out_feats * num_heads:
                self.res_fc = nn.Linear(in_feats, out_feats * num_heads, bias=False)
            else:
                self.res_fc = nn.Identity()

        # If any node has zero in-degree and allow_zero_in_degree=False, raise error
        self.allow_zero_in_degree = allow_zero_in_degree

        # Initialize all weights
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize all learnable parameters with Xavier (Glorot) normal initialization
        for weights and leave biases at zero.
        """
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
        nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
        if self.use_edge_features:
            nn.init.xavier_normal_(self.fc_edge.weight, gain=gain)

        nn.init.xavier_normal_(self.attn_l, gain=gain)
        nn.init.xavier_normal_(self.attn_r, gain=gain)

    def forward(self, graph, feat, edge_feat=None, get_attention=False):
        """
        Forward pass of the PhysicsAwareGATv2Conv layer.

        Arguments:
            graph (DGLGraph): A DGLGraph object containing the connectivity. The graph
                              should have been constructed so that src and dst fields
                              correspond to the desired direction of message passing.
            feat (Tensor): Node feature matrix of shape (num_nodes, in_feats).
            edge_feat (Tensor or None): Edge feature matrix of shape (num_edges, in_feats).
            get_attention (bool): If True, also return the raw attention scores for each edge.

        Returns:
            If get_attention=False:
                h_out (Tensor): Node embeddings of shape (num_nodes, num_heads, out_feats).
            If get_attention=True:
                (h_out, attn_scores) where
                - h_out is as above,
                - attn_scores is a Tensor of shape (num_edges, num_heads, 1)
                  representing the unnormalized attention logits per head and edge.
        """
        with graph.local_scope():
            # If zero in-degree nodes exist and not allowed, throw an error
            if not self.allow_zero_in_degree:
                zero_indegree = (graph.in_degrees() == 0)
                if zero_indegree.any():
                    raise RuntimeError("Zero in-degree nodes detected. Add self-loops.")

            num_nodes = graph.num_nodes()

            # 1. Apply dropout to input node features before projection
            h_src = self.feat_drop(feat)  # shape: (num_nodes, in_feats)
            h_dst = self.feat_drop(feat)  # same shape, used for target-side projection

            # 2. Project source and target node features via separate linear layers
            #    and reshape to (num_nodes, num_heads, out_feats)
            h_src_proj = self.fc_src(h_src).view(num_nodes, self.num_heads, self.out_feats)
            h_dst_proj = self.fc_dst(h_dst).view(num_nodes, self.num_heads, self.out_feats)

            # Store projected features for message passing
            graph.srcdata['ft'] = h_src_proj  # features for source nodes
            graph.dstdata['ft'] = h_dst_proj  # features for destination nodes

            # 3. Compute attention coefficient "prefixes" for each node:
            #    el = sum_over_features( h_src_proj * attn_l ), shape: (num_nodes, num_heads, 1)
            #    er = sum_over_features( h_dst_proj * attn_r ), shape: (num_nodes, num_heads, 1)
            el = (h_src_proj * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (h_dst_proj * self.attn_r).sum(dim=-1).unsqueeze(-1)

            # Store the left/right components in node data so we can combine on edges
            graph.srcdata['el'] = el
            graph.dstdata['er'] = er

            # 4. If edge features are provided and enabled, project them similarly
            if self.use_edge_features and edge_feat is not None:
                edge_feat_dropped = self.feat_drop(edge_feat)  # dropout on edge inputs
                edge_feat_proj = self.fc_edge(edge_feat_dropped).view(-1, self.num_heads, self.out_feats)
                graph.edata['edge_feat'] = edge_feat_proj

            # 5. For each edge (u→v), compute a_base = el[u] + er[v]
            #    This sums the contributions from source and destination nodes
            graph.apply_edges(lambda edges: {'a_base': edges.src['el'] + edges.dst['er']})

            # 6. If using edge features, also add their contribution to the attention logits
            if self.use_edge_features and 'edge_feat' in graph.edata:
                # Compute (edge_feat × attn_l) summed over feature dims → shape: (num_edges, num_heads, 1)
                edge_attn = (graph.edata['edge_feat'] * self.attn_l).sum(dim=-1).unsqueeze(-1)
                graph.edata['a_base'] += edge_attn

            # 7. Apply LeakyReLU to get unnormalized attention logits
            graph.edata['a_base'] = F.leaky_relu(graph.edata['a_base'], self.negative_slope)

            # 8. Apply dropout to the attention logits
            graph.edata['a'] = self.attn_drop(graph.edata['a_base'])

            # 9. Perform message passing:
            #    - message_func sends {'ft': projected_src_feats, 'a': attention_logits} from each src node
            #    - reduce_func computes normalized softmax over incoming edges and aggregates features
            graph.update_all(self.message_func, self.reduce_func)

            # 10. After aggregation, we have stored the updated node features in 'h'
            h_out = graph.dstdata['h']  # shape: (num_nodes, num_heads, out_feats)

            # 11. Apply activation if provided
            if self.activation:
                h_out = self.activation(h_out)

            # 12. If residual connection is enabled, add the (possibly projected) original input
            if self.residual:
                res = self.res_fc(feat).view(num_nodes, self.num_heads, self.out_feats)
                h_out = h_out + res

            # Return either just the new features or also the raw attention logits
            if get_attention:
                return h_out, graph.edata['a']  # shape of 'a': (num_edges, num_heads, 1)
            else:
                return h_out

    def message_func(self, edges):
        """
        Message function for DGL update_all.
        Passes projected source features 'ft' and raw attention 'a' along each edge.
        """
        return {'ft': edges.src['ft'], 'a': edges.data['a']}

    def reduce_func(self, nodes):
        """
        Reduce function for DGL update_all. For each receiving node:
          1. Retrieve the incoming attention logits (nodes.mailbox['a']), shape: (batch_size, num_neighbors, num_heads, 1)
          2. Do a numerically stable softmax over the neighbors dimension:
               alpha = exp(a - max(a)) / sum(exp(a - max(a)))
          3. Weight sum the incoming 'ft' features (nodes.mailbox['ft']) by alpha:
               h = sum(alpha * ft)
          4. Store the aggregated feature in nodes.data['h'].

        Returns:
            dict with key 'h': Tensor of shape (batch_size, num_heads, out_feats)
        """
        a = nodes.mailbox['a']  # shape: (batch_size, num_in_edges, num_heads, 1)
        # Compute max over neighbor dimension for numerical stability: shape (batch_size, 1, num_heads, 1)
        max_a = torch.max(a, dim=1, keepdim=True)[0]
        exp_a = torch.exp(a - max_a)   # subtract max to stabilize
        alpha = exp_a / torch.sum(exp_a, dim=1, keepdim=True)  # normalize over neighbors
        # Weighted sum of projected source features: shape (batch_size, num_heads, out_feats)
        h = torch.sum(alpha * nodes.mailbox['ft'], dim=1)
        return {'h': h}


class PhysicsAwareGATv2(nn.Module):
    """
    A two-layer physics-aware GATv2 network that produces a pair of outputs per edge:
      - diag_out ∈ ℝ^m: a 1-dimensional signal per edge (e.g., main diagonal correction)
      - low_out ∈ ℝ^{m×r}: a low-rank embedding per edge (r = low_rank_dim)

    The architecture:
      - Layer 1: PhysicsAwareGATv2Conv(in_feats → hidden_size, num_heads)
      - Layer 2: PhysicsAwareGATv2Conv(hidden_size * num_heads → 1 + low_rank_dim, 1 head)
      - Output activation (e.g., sigmoid) applied elementwise
      - diag_out extracted as the first component of the final embedding
      - low_out as the remaining r components.

    Arguments:
        in_feats (int): Dimension of input node features.
        hidden_size (int): Output dimension per head in the first GAT layer.
        n_iter (int): Number of "iterations" or epochs (passed to optimizer but not used internally here).
        lr (float): Learning rate for the Adam optimizer.
        num_heads (int): Number of heads for the first GAT layer.
        low_rank_dim (int): Dimension r of the low-rank output embedding per edge.
        early_stop (int): Patience for early stopping (not used internally here).
        output_activation (callable): Activation to apply to the final raw output.
        use_edge_features (bool): Whether to pass edge features to the conv layers.
    """

    def __init__(self, in_feats, hidden_size, n_iter, lr, num_heads=1,
                 low_rank_dim=4, early_stop=10, output_activation=torch.sigmoid,
                 use_edge_features=True):
        super(PhysicsAwareGATv2, self).__init__()

        self.low_rank_dim = low_rank_dim
        self.num_heads = num_heads
        self.n_iter = n_iter
        self.early_stop = early_stop
        self.output_activation = output_activation

        # First GATv2 layer: projects in_feats → hidden_size per head, with num_heads
        self.gatv2_1 = PhysicsAwareGATv2Conv(
            in_feats=in_feats,
            out_feats=hidden_size,
            num_heads=num_heads,
            feat_drop=0.1,
            attn_drop=0.1,
            negative_slope=0.2,
            residual=True,
            activation=F.elu,
            allow_zero_in_degree=True,
            use_edge_features=use_edge_features
        )

        # Second GATv2 layer: projects (hidden_size * num_heads) → (1 + low_rank_dim), with 1 head
        self.gatv2_2 = PhysicsAwareGATv2Conv(
            in_feats=hidden_size * num_heads,
            out_feats=(1 + low_rank_dim),
            num_heads=1,
            feat_drop=0.1,
            attn_drop=0.1,
            negative_slope=0.2,
            residual=True,
            activation=None,  # No activation; apply output_activation later
            allow_zero_in_degree=True,
            use_edge_features=use_edge_features
        )

        # Adam optimizer over all parameters (GAT layers)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            self.cuda()

    def forward(self, g, inputs, edge_features=None, return_attention=False):
        """
        Forward pass producing edge-level outputs.

        Steps:
          1. Pass node features through layer 1:
               h1 = gatv2_1(g, inputs, edge_features)
             Optionally retrieve raw attention logits from this layer if return_attention=True.
          2. Reshape h1 from (num_nodes, num_heads, hidden_size) to (num_nodes, hidden_size * num_heads).
          3. Pass through layer 2:
               h2 = gatv2_2(g, h1_flattened, edge_features)
             Optionally retrieve raw attention logits from layer 2.
          4. h2 has shape (num_nodes, 1 + low_rank_dim, 1) per node, squeeze the singleton head dimension.
          5. Apply output_activation to each of the (1 + low_rank_dim) values per node.
          6. Interpret:
               diag_out = h2[:, 0]           # main diagonal correction per edge
               low_out  = h2[:, 1:]          # low_rank_dim embeddings per edge
          7. If return_attention=True, compute a combined attention per edge by averaging
             the two layers’ raw attentions, then return (diag_out, low_out, combined_attn).
             Otherwise, return (diag_out, low_out).

        Arguments:
            g (DGLGraph): The graph (with n edges). The edges in g correspond to exactly those
                          edges we want an output for.
            inputs (Tensor): Node feature matrix of shape (num_nodes, in_feats).
            edge_features (Tensor or None): Edge feature matrix of shape (num_edges, in_feats).
            return_attention (bool): Whether to also return a dict mapping edge-id → attention score.

        Returns:
            If return_attention=False:
                (diag_out, low_out) where
                  - diag_out is Tensor (num_edges,) containing the first dimension of output per edge.
                  - low_out is Tensor (num_edges, low_rank_dim) containing the remaining dims.
            If return_attention=True:
                (diag_out, low_out, combined_attn) where combined_attn is a dict:
                  key: edge ID (0..num_edges−1), value: average of attention from layer1 and layer2.
        """
        # Layer 1 forward
        if return_attention:
            # Retrieve both embeddings and raw attention logits from layer1
            h1, attn1 = self.gatv2_1(g, inputs, edge_features, get_attention=True)
        else:
            h1 = self.gatv2_1(g, inputs, edge_features)

        # h1 shape: (num_nodes, num_heads, hidden_size). Flatten the head dimension
        h1_flat = h1.view(h1.shape[0], -1)  # (num_nodes, hidden_size * num_heads)

        # Layer 2 forward
        if return_attention:
            h2, attn2 = self.gatv2_2(g, h1_flat, edge_features, get_attention=True)
        else:
            h2 = self.gatv2_2(g, h1_flat, edge_features)

        # h2 shape: (num_nodes, 1 + low_rank_dim, 1). Squeeze the head dimension
        h2 = h2.squeeze(1)  # (num_nodes, 1 + low_rank_dim)
        # Apply the final activation (e.g., sigmoid) elementwise
        h2 = self.output_activation(h2)

        # Extract diag_out (first column) and low_out (remaining columns)
        diag_out = h2[:, 0]      # (num_nodes,)
        low_out = h2[:, 1:]      # (num_nodes, low_rank_dim)

        if return_attention:
            # Combine layer1 and layer2 attentions by averaging per edge
            edge_ids = g.edges(form='eid')  # Tensor of shape (num_edges,)
            # attn1 and attn2 each have shape (num_edges, num_heads, 1)
            combined_attn = {
                eid.item(): (attn1[i].mean().item() + attn2[i].mean().item()) / 2.0
                for i, eid in enumerate(edge_ids)
            }
            return diag_out, low_out, combined_attn
        else:
            return diag_out, low_out
